{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9c3f9467",
   "metadata": {},
   "source": [
    "# Semantic Similarity"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35f9065d",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-info\">\n",
    "\n",
    "This tutorial is available as an IPython notebook at [Malaya/example/similarity-semantic](https://github.com/huseinzol05/Malaya/tree/master/example/similarity-semantic).\n",
    "    \n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be978366",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-info\">\n",
    "\n",
    "This module was trained on both standard and local (including social media) language structures, so it is safe to use for both.\n",
    "    \n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3bffd739",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = ''\n",
    "os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d997083d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "\n",
    "logging.basicConfig(level=logging.INFO)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3fc1e6c6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/husein/.local/lib/python3.8/site-packages/bitsandbytes/cextension.py:34: UserWarning: The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.\n",
      "  warn(\"The installed version of bitsandbytes was compiled without GPU support. \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/husein/.local/lib/python3.8/site-packages/bitsandbytes/libbitsandbytes_cpu.so: undefined symbol: cadam32bit_grad_fp32\n",
      "CPU times: user 3.19 s, sys: 2.59 s, total: 5.79 s\n",
      "Wall time: 2.76 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/husein/dev/malaya/malaya/tokenizer.py:214: FutureWarning: Possible nested set at position 3397\n",
      "  self.tok = re.compile(r'({})'.format('|'.join(pipeline)))\n",
      "/home/husein/dev/malaya/malaya/tokenizer.py:214: FutureWarning: Possible nested set at position 3927\n",
      "  self.tok = re.compile(r'({})'.format('|'.join(pipeline)))\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "import malaya"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "15ea6b17",
   "metadata": {},
   "outputs": [],
   "source": [
    "string1 = 'Pemuda mogok lapar desak kerajaan prihatin isu iklim'\n",
    "string2 = 'Perbincangan isu pembalakan perlu babit kerajaan negeri'\n",
    "string3 = 'kerajaan perlu kisah isu iklim, pemuda mogok lapar'\n",
    "string4 = 'Kerajaan dicadang tubuh jawatankuasa khas tangani isu alam sekitar'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7dc663d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "news1 = 'Tun Dr Mahathir Mohamad mengakui pembubaran Parlimen bagi membolehkan pilihan raya diadakan tidak sesuai dilaksanakan pada masa ini berikutan isu COVID-19'\n",
    "tweet1 = 'DrM sembang pilihan raya tak boleh buat sebab COVID 19'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9d05c4c",
   "metadata": {},
   "source": [
    "### List available HuggingFace models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "742c6446",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'mesolitica/finetune-mnli-nanot5-small': {'Size (MB)': 148,\n",
       "  'macro precision': 0.87125,\n",
       "  'macro recall': 0.87131,\n",
       "  'macro f1-score': 0.87127},\n",
       " 'mesolitica/finetune-mnli-nanot5-base': {'Size (MB)': 892,\n",
       "  'macro precision': 0.78903,\n",
       "  'macro recall': 0.79064,\n",
       "  'macro f1-score': 0.78918}}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "malaya.similarity.semantic.available_huggingface"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "50885ed2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tested on matched dev set translated MNLI, https://huggingface.co/datasets/mesolitica/translated-MNLI\n"
     ]
    }
   ],
   "source": [
    "print(malaya.similarity.semantic.info)"
   ]
  },
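  {
   "cell_type": "markdown",
   "id": "b7e41a02",
   "metadata": {},
   "source": [
    "Since `available_huggingface` is displayed as a plain dictionary, a model can be picked programmatically. The following is a minimal sketch, assuming the dictionary keeps the shape shown in the output above (the scores are copied from it). `best_model` is a hypothetical helper, not part of Malaya.\n",
    "\n",
    "```python\n",
    "# Scores copied from the `available_huggingface` output above.\n",
    "available = {\n",
    "    'mesolitica/finetune-mnli-nanot5-small': {\n",
    "        'Size (MB)': 148,\n",
    "        'macro precision': 0.87125,\n",
    "        'macro recall': 0.87131,\n",
    "        'macro f1-score': 0.87127,\n",
    "    },\n",
    "    'mesolitica/finetune-mnli-nanot5-base': {\n",
    "        'Size (MB)': 892,\n",
    "        'macro precision': 0.78903,\n",
    "        'macro recall': 0.79064,\n",
    "        'macro f1-score': 0.78918,\n",
    "    },\n",
    "}\n",
    "\n",
    "\n",
    "def best_model(models, metric='macro f1-score'):\n",
    "    # Hypothetical helper: return the model name with the highest `metric`.\n",
    "    return max(models, key=lambda name: models[name][metric])\n",
    "\n",
    "\n",
    "print(best_model(available))  # mesolitica/finetune-mnli-nanot5-small\n",
    "```\n",
    "\n",
    "The selected name can then be passed as the `model` argument of `malaya.similarity.semantic.huggingface`."
   ]
  },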
  {
   "cell_type": "markdown",
   "id": "419d30dc",
   "metadata": {},
   "source": [
    "### Load HuggingFace model\n",
    "\n",
    "```python\n",
    "def huggingface(\n",
    "    model: str = 'mesolitica/finetune-mnli-t5-small-standard-bahasa-cased',\n",
    "    force_check: bool = True,\n",
    "    **kwargs,\n",
    "):\n",
    "    \"\"\"\n",
    "    Load HuggingFace model to calculate semantic similarity between 2 sentences.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    model: str, optional (default='mesolitica/finetune-mnli-t5-small-standard-bahasa-cased')\n",
    "        Check available models at `malaya.similarity.semantic.available_huggingface`.\n",
    "    force_check: bool, optional (default=True)\n",
    "        Force a check that the model is one of the Malaya models.\n",
    "        Set to False if you have your own HuggingFace model.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    result: malaya.torch_model.huggingface.Similarity\n",
    "    \"\"\"\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5fdd641e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    }
   ],
   "source": [
    "model = malaya.similarity.semantic.huggingface()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96ab9368",
   "metadata": {},
   "source": [
    "#### Predict a batch of strings with probability\n",
    "\n",
    "```python\n",
    "def predict_proba(self, strings_left: List[str], strings_right: List[str]):\n",
    "        \"\"\"\n",
    "        Calculate similarity for two batches of texts.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        strings_left : List[str]\n",
    "        strings_right : List[str]\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        list: List[float]\n",
    "        \"\"\"\n",
    "```\n",
    "\n",
    "You need to provide a list of left strings and a list of right strings.\n",
    "\n",
    "The first left string is compared with the first right string, the second with the second, and so on.\n",
    "\n",
    "The similarity model only supports `predict_proba`."
   ]
  },
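  {
   "cell_type": "markdown",
   "id": "c28d5f13",
   "metadata": {},
   "source": [
    "Because the two lists are compared position by position, scoring every left string against every right string means expanding the inputs into aligned lists first. The following is a minimal sketch of that expansion. `all_pairs` is a hypothetical helper, not part of Malaya, and the short strings are placeholders.\n",
    "\n",
    "```python\n",
    "from itertools import product\n",
    "\n",
    "\n",
    "def all_pairs(lefts, rights):\n",
    "    # Hypothetical helper: build two aligned lists covering every\n",
    "    # (left, right) combination, in row-major order.\n",
    "    pairs = list(product(lefts, rights))\n",
    "    left_aligned = [left for left, _ in pairs]\n",
    "    right_aligned = [right for _, right in pairs]\n",
    "    return left_aligned, right_aligned\n",
    "\n",
    "\n",
    "lefts, rights = all_pairs(['a', 'b'], ['x', 'y', 'z'])\n",
    "print(lefts)   # ['a', 'a', 'a', 'b', 'b', 'b']\n",
    "print(rights)  # ['x', 'y', 'z', 'x', 'y', 'z']\n",
    "```\n",
    "\n",
    "The two aligned lists can then be passed to `predict_proba` as `strings_left` and `strings_right`. The scores come back flat, in the same row-major order."
   ]
  },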
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8650cf14",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([0.9980286 , 0.02898298, 0.79173875, 0.9694676 ], dtype=float32)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.predict_proba([string1, string2, news1, news1], [string3, string4, tweet1, string1])"
   ]
  },
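  {
   "cell_type": "markdown",
   "id": "d93a6e24",
   "metadata": {},
   "source": [
    "The returned probabilities can be turned into yes/no decisions with a cut-off. The following is a minimal sketch using the values printed above. The `0.5` threshold is an assumption chosen for illustration, not a value recommended by Malaya, and `is_similar` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "# Probabilities copied from the `predict_proba` output above.\n",
    "probs = [0.9980286, 0.02898298, 0.79173875, 0.9694676]\n",
    "\n",
    "\n",
    "def is_similar(probabilities, threshold=0.5):\n",
    "    # Hypothetical helper: True where the score reaches the assumed cut-off.\n",
    "    return [float(p) >= threshold for p in probabilities]\n",
    "\n",
    "\n",
    "print(is_similar(probs))  # [True, False, True, True]\n",
    "```\n",
    "\n",
    "A stricter or looser threshold trades precision against recall, so it is worth tuning on labelled pairs from your own data."
   ]
  },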
  {
   "cell_type": "markdown",
   "id": "0a93f8d1",
   "metadata": {},
   "source": [
    "### Inference on mixed Malay (MS) and English (EN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "94d92d9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "en = 'Youth on hunger strike urge the government to be concerned about the climate issue'\n",
    "en2 = 'the end of the wrld, global warming!'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3fa1f404",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.00405155, 0.7731996 ], dtype=float32)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.predict_proba([string1, string1], [en2, en])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
